# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License

"""Utilities for converting TFHub BigGAN generator weights to PyTorch.

Recommended usage:

To convert all BigGAN variants and generate test samples, use:

```bash
CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples
```

See `parse_args` for additional options.
"""

import argparse
import os
import sys

import h5py
import torch
import torch.nn as nn
from torchvision.utils import save_image
import tensorflow as tf
import tensorflow_hub as hub
import parse

# import reference biggan from this folder
import biggan_v1 as biggan_for_conversion

# Import model from main folder
sys.path.append("..")
import BigGAN


DEVICE = "cuda"
HDF5_TMPL = "biggan-{}.h5"
PTH_TMPL = "biggan-{}.pth"
MODULE_PATH_TMPL = "https://tfhub.dev/deepmind/biggan-{}/2"
Z_DIMS = {128: 120, 256: 140, 512: 128}
RESOLUTIONS = list(Z_DIMS)


def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False):
    """Loads TFHub weights and saves them to intermediate HDF5 file.

  Args:
    module_path ([Path-like]): Path to TFHub module.
    hdf5_path ([Path-like]): Path to output HDF5 file.

  Returns:
    [h5py.File]: Loaded hdf5 file containing module weights.
  """
    if os.path.exists(hdf5_path) and (not redownload):
        print("Loading BigGAN hdf5 file from:", hdf5_path)
        return h5py.File(hdf5_path, "r")

    print("Loading BigGAN module from:", module_path)
    tf.reset_default_graph()
    hub.Module(module_path)
    print("Loaded BigGAN module from:", module_path)

    initializer = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(initializer)

    print("Saving BigGAN weights to :", hdf5_path)
    h5f = h5py.File(hdf5_path, "w")
    for var in tf.global_variables():
        val = sess.run(var)
        h5f.create_dataset(var.name, data=val)
        print(f"Saving {var.name} with shape {val.shape}")
    h5f.close()
    return h5py.File(hdf5_path, "r")


class TFHub2Pytorch(object):

    TF_ROOT = "module"

    NUM_GBLOCK = {128: 5, 256: 6, 512: 7}

    w = "w"
    b = "b"
    u = "u0"
    v = "u1"
    gamma = "gamma"
    beta = "beta"

    def __init__(
        self, state_dict, tf_weights, resolution=256, load_ema=True, verbose=False
    ):
        self.state_dict = state_dict
        self.tf_weights = tf_weights
        self.resolution = resolution
        self.verbose = verbose
        if load_ema:
            for name in ["w", "b", "gamma", "beta"]:
                setattr(self, name, getattr(self, name) + "/ema_b999900")

    def load(self):
        self.load_generator()
        return self.state_dict

    def load_generator(self):
        GENERATOR_ROOT = os.path.join(self.TF_ROOT, "Generator")

        for i in range(self.NUM_GBLOCK[self.resolution]):
            name_tf = os.path.join(GENERATOR_ROOT, "GBlock")
            name_tf += f"_{i}" if i != 0 else ""
            self.load_GBlock(f"GBlock.{i}.", name_tf)

        self.load_attention("attention.", os.path.join(GENERATOR_ROOT, "attention"))
        self.load_linear("linear", os.path.join(self.TF_ROOT, "linear"), bias=False)
        self.load_snlinear("G_linear", os.path.join(GENERATOR_ROOT, "G_Z", "G_linear"))
        self.load_colorize("colorize", os.path.join(GENERATOR_ROOT, "conv_2d"))
        self.load_ScaledCrossReplicaBNs(
            "ScaledCrossReplicaBN", os.path.join(GENERATOR_ROOT, "ScaledCrossReplicaBN")
        )

    def load_linear(self, name_pth, name_tf, bias=True):
        self.state_dict[name_pth + ".weight"] = self.load_tf_tensor(
            name_tf, self.w
        ).permute(1, 0)
        if bias:
            self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(name_tf, self.b)

    def load_snlinear(self, name_pth, name_tf, bias=True):
        self.state_dict[name_pth + ".module.weight_u"] = self.load_tf_tensor(
            name_tf, self.u
        ).squeeze()
        self.state_dict[name_pth + ".module.weight_v"] = self.load_tf_tensor(
            name_tf, self.v
        ).squeeze()
        self.state_dict[name_pth + ".module.weight_bar"] = self.load_tf_tensor(
            name_tf, self.w
        ).permute(1, 0)
        if bias:
            self.state_dict[name_pth + ".module.bias"] = self.load_tf_tensor(
                name_tf, self.b
            )

    def load_colorize(self, name_pth, name_tf):
        self.load_snconv(name_pth, name_tf)

    def load_GBlock(self, name_pth, name_tf):
        self.load_convs(name_pth, name_tf)
        self.load_HyperBNs(name_pth, name_tf)

    def load_convs(self, name_pth, name_tf):
        self.load_snconv(name_pth + "conv0", os.path.join(name_tf, "conv0"))
        self.load_snconv(name_pth + "conv1", os.path.join(name_tf, "conv1"))
        self.load_snconv(name_pth + "conv_sc", os.path.join(name_tf, "conv_sc"))

    def load_snconv(self, name_pth, name_tf, bias=True):
        if self.verbose:
            print(f"loading: {name_pth} from {name_tf}")
        self.state_dict[name_pth + ".module.weight_u"] = self.load_tf_tensor(
            name_tf, self.u
        ).squeeze()
        self.state_dict[name_pth + ".module.weight_v"] = self.load_tf_tensor(
            name_tf, self.v
        ).squeeze()
        self.state_dict[name_pth + ".module.weight_bar"] = self.load_tf_tensor(
            name_tf, self.w
        ).permute(3, 2, 0, 1)
        if bias:
            self.state_dict[name_pth + ".module.bias"] = self.load_tf_tensor(
                name_tf, self.b
            ).squeeze()

    def load_conv(self, name_pth, name_tf, bias=True):

        self.state_dict[name_pth + ".weight_u"] = self.load_tf_tensor(
            name_tf, self.u
        ).squeeze()
        self.state_dict[name_pth + ".weight_v"] = self.load_tf_tensor(
            name_tf, self.v
        ).squeeze()
        self.state_dict[name_pth + ".weight_bar"] = self.load_tf_tensor(
            name_tf, self.w
        ).permute(3, 2, 0, 1)
        if bias:
            self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(name_tf, self.b)

    def load_HyperBNs(self, name_pth, name_tf):
        self.load_HyperBN(name_pth + "HyperBN", os.path.join(name_tf, "HyperBN"))
        self.load_HyperBN(name_pth + "HyperBN_1", os.path.join(name_tf, "HyperBN_1"))

    def load_ScaledCrossReplicaBNs(self, name_pth, name_tf):
        self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(
            name_tf, self.beta
        ).squeeze()
        self.state_dict[name_pth + ".weight"] = self.load_tf_tensor(
            name_tf, self.gamma
        ).squeeze()
        self.state_dict[name_pth + ".running_mean"] = self.load_tf_tensor(
            name_tf + "bn", "accumulated_mean"
        )
        self.state_dict[name_pth + ".running_var"] = self.load_tf_tensor(
            name_tf + "bn", "accumulated_var"
        )
        self.state_dict[name_pth + ".num_batches_tracked"] = torch.tensor(
            self.tf_weights[os.path.join(name_tf + "bn", "accumulation_counter:0")][()],
            dtype=torch.float32,
        )

    def load_HyperBN(self, name_pth, name_tf):
        if self.verbose:
            print(f"loading: {name_pth} from {name_tf}")
        beta = name_pth + ".beta_embed.module"
        gamma = name_pth + ".gamma_embed.module"
        self.state_dict[beta + ".weight_u"] = self.load_tf_tensor(
            os.path.join(name_tf, "beta"), self.u
        ).squeeze()
        self.state_dict[gamma + ".weight_u"] = self.load_tf_tensor(
            os.path.join(name_tf, "gamma"), self.u
        ).squeeze()
        self.state_dict[beta + ".weight_v"] = self.load_tf_tensor(
            os.path.join(name_tf, "beta"), self.v
        ).squeeze()
        self.state_dict[gamma + ".weight_v"] = self.load_tf_tensor(
            os.path.join(name_tf, "gamma"), self.v
        ).squeeze()
        self.state_dict[beta + ".weight_bar"] = self.load_tf_tensor(
            os.path.join(name_tf, "beta"), self.w
        ).permute(1, 0)
        self.state_dict[gamma + ".weight_bar"] = self.load_tf_tensor(
            os.path.join(name_tf, "gamma"), self.w
        ).permute(1, 0)

        cr_bn_name = name_tf.replace("HyperBN", "CrossReplicaBN")
        self.state_dict[name_pth + ".bn.running_mean"] = self.load_tf_tensor(
            cr_bn_name, "accumulated_mean"
        )
        self.state_dict[name_pth + ".bn.running_var"] = self.load_tf_tensor(
            cr_bn_name, "accumulated_var"
        )
        self.state_dict[name_pth + ".bn.num_batches_tracked"] = torch.tensor(
            self.tf_weights[os.path.join(cr_bn_name, "accumulation_counter:0")][()],
            dtype=torch.float32,
        )

    def load_attention(self, name_pth, name_tf):

        self.load_snconv(name_pth + "theta", os.path.join(name_tf, "theta"), bias=False)
        self.load_snconv(name_pth + "phi", os.path.join(name_tf, "phi"), bias=False)
        self.load_snconv(name_pth + "g", os.path.join(name_tf, "g"), bias=False)
        self.load_snconv(
            name_pth + "o_conv", os.path.join(name_tf, "o_conv"), bias=False
        )
        self.state_dict[name_pth + "gamma"] = self.load_tf_tensor(name_tf, self.gamma)

    def load_tf_tensor(self, prefix, var, device="0"):
        name = os.path.join(prefix, var) + f":{device}"
        return torch.from_numpy(self.tf_weights[name][:])


# Convert from v1: This function maps
def convert_from_v1(hub_dict, resolution=128):
    weightname_dict = {"weight_u": "u0", "weight_bar": "weight", "bias": "bias"}
    convnum_dict = {"conv0": "conv1", "conv1": "conv2", "conv_sc": "conv_sc"}
    attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution]
    hub2me = {
        "linear.weight": "shared.weight",  # This is actually the shared weight
        # Linear stuff
        "G_linear.module.weight_bar": "linear.weight",
        "G_linear.module.bias": "linear.bias",
        "G_linear.module.weight_u": "linear.u0",
        # output layer stuff
        "ScaledCrossReplicaBN.weight": "output_layer.0.gain",
        "ScaledCrossReplicaBN.bias": "output_layer.0.bias",
        "ScaledCrossReplicaBN.running_mean": "output_layer.0.stored_mean",
        "ScaledCrossReplicaBN.running_var": "output_layer.0.stored_var",
        "colorize.module.weight_bar": "output_layer.2.weight",
        "colorize.module.bias": "output_layer.2.bias",
        "colorize.module.weight_u": "output_layer.2.u0",
        # Attention stuff
        "attention.gamma": "blocks.%d.1.gamma" % attention_blocknum,
        "attention.theta.module.weight_u": "blocks.%d.1.theta.u0" % attention_blocknum,
        "attention.theta.module.weight_bar": "blocks.%d.1.theta.weight"
        % attention_blocknum,
        "attention.phi.module.weight_u": "blocks.%d.1.phi.u0" % attention_blocknum,
        "attention.phi.module.weight_bar": "blocks.%d.1.phi.weight"
        % attention_blocknum,
        "attention.g.module.weight_u": "blocks.%d.1.g.u0" % attention_blocknum,
        "attention.g.module.weight_bar": "blocks.%d.1.g.weight" % attention_blocknum,
        "attention.o_conv.module.weight_u": "blocks.%d.1.o.u0" % attention_blocknum,
        "attention.o_conv.module.weight_bar": "blocks.%d.1.o.weight"
        % attention_blocknum,
    }

    # Loop over the hub dict and build the hub2me map
    for name in hub_dict.keys():
        if "GBlock" in name:
            if "HyperBN" not in name:  # it's a conv
                out = parse.parse("GBlock.{:d}.{}.module.{}", name)
                blocknum, convnum, weightname = out
                if weightname not in weightname_dict:
                    continue  # else hyperBN in
                out_name = "blocks.%d.0.%s.%s" % (
                    blocknum,
                    convnum_dict[convnum],
                    weightname_dict[weightname],
                )  # Increment conv number by 1
            else:  # hyperbn not conv
                BNnum = 2 if "HyperBN_1" in name else 1
                if "embed" in name:
                    out = parse.parse("GBlock.{:d}.{}.module.{}", name)
                    blocknum, gamma_or_beta, weightname = out
                    if weightname not in weightname_dict:  # Ignore weight_v
                        continue
                    out_name = "blocks.%d.0.bn%d.%s.%s" % (
                        blocknum,
                        BNnum,
                        "gain" if "gamma" in gamma_or_beta else "bias",
                        weightname_dict[weightname],
                    )
                else:
                    out = parse.parse("GBlock.{:d}.{}.bn.{}", name)
                    blocknum, dummy, mean_or_var = out
                    if "num_batches_tracked" in mean_or_var:
                        continue
                    out_name = "blocks.%d.0.bn%d.%s" % (
                        blocknum,
                        BNnum,
                        "stored_mean" if "mean" in mean_or_var else "stored_var",
                    )
            hub2me[name] = out_name

    # Invert the hub2me map
    me2hub = {hub2me[item]: item for item in hub2me}
    new_dict = {}
    dimz_dict = {128: 20, 256: 20, 512: 16}
    for item in me2hub:
        # Swap input dim ordering on batchnorm bois to account for my arbitrary change of ordering when concatenating Ys and Zs
        if (
            ("bn" in item and "weight" in item)
            and ("gain" in item or "bias" in item)
            and ("output_layer" not in item)
        ):
            new_dict[item] = torch.cat(
                [
                    hub_dict[me2hub[item]][:, -128:],
                    hub_dict[me2hub[item]][:, : dimz_dict[resolution]],
                ],
                1,
            )
        # Reshape the first linear weight, bias, and u0
        elif item == "linear.weight":
            new_dict[item] = (
                hub_dict[me2hub[item]]
                .contiguous()
                .view(4, 4, 96 * 16, -1)
                .permute(2, 0, 1, 3)
                .contiguous()
                .view(-1, dimz_dict[resolution])
            )
        elif item == "linear.bias":
            new_dict[item] = (
                hub_dict[me2hub[item]]
                .view(4, 4, 96 * 16)
                .permute(2, 0, 1)
                .contiguous()
                .view(-1)
            )
        elif item == "linear.u0":
            new_dict[item] = (
                hub_dict[me2hub[item]]
                .view(4, 4, 96 * 16)
                .permute(2, 0, 1)
                .contiguous()
                .view(1, -1)
            )
        elif (
            me2hub[item] == "linear.weight"
        ):  # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER
            # Transpose shared weight so that it's an embedding
            new_dict[item] = hub_dict[me2hub[item]].t()
        elif "weight_u" in me2hub[item]:  # Unsqueeze u0s
            new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0)
        else:
            new_dict[item] = hub_dict[me2hub[item]]
    return new_dict


def get_config(resolution):
    attn_dict = {128: "64", 256: "128", 512: "64"}
    dim_z_dict = {128: 120, 256: 140, 512: 128}
    config = {
        "G_param": "SN",
        "D_param": "SN",
        "G_ch": 96,
        "D_ch": 96,
        "D_wide": True,
        "G_shared": True,
        "shared_dim": 128,
        "dim_z": dim_z_dict[resolution],
        "hier": True,
        "cross_replica": False,
        "mybn": False,
        "G_activation": nn.ReLU(inplace=True),
        "G_attn": attn_dict[resolution],
        "norm_style": "bn",
        "G_init": "ortho",
        "skip_init": True,
        "no_optim": True,
        "G_fp16": False,
        "G_mixed_precision": False,
        "accumulate_stats": False,
        "num_standing_accumulations": 16,
        "G_eval_mode": True,
        "BN_eps": 1e-04,
        "SN_eps": 1e-04,
        "num_G_SVs": 1,
        "num_G_SV_itrs": 1,
        "resolution": resolution,
        "n_classes": 1000,
    }
    return config


def convert_biggan(
    resolution, weight_dir, redownload=False, no_ema=False, verbose=False
):
    module_path = MODULE_PATH_TMPL.format(resolution)
    hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution))
    pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution))

    tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload)
    G_temp = getattr(biggan_for_conversion, f"Generator{resolution}")()
    state_dict_temp = G_temp.state_dict()

    converter = TFHub2Pytorch(
        state_dict_temp,
        tf_weights,
        resolution=resolution,
        load_ema=(not no_ema),
        verbose=verbose,
    )
    state_dict_v1 = converter.load()
    state_dict = convert_from_v1(state_dict_v1, resolution)
    # Get the config, build the model
    config = get_config(resolution)
    G = BigGAN.Generator(**config)
    G.load_state_dict(state_dict, strict=False)  # Ignore missing sv0 entries
    torch.save(state_dict, pth_path)

    # output_location ='pretrained_weights/TFHub-PyTorch-128.pth'

    return G


def generate_sample(G, z_dim, batch_size, filename, parallel=False):

    G.eval()
    G.to(DEVICE)
    with torch.no_grad():
        z = torch.randn(batch_size, G.dim_z).to(DEVICE)
        y = torch.randint(
            low=0,
            high=1000,
            size=(batch_size,),
            device=DEVICE,
            dtype=torch.int64,
            requires_grad=False,
        )
        if parallel:
            images = nn.parallel.data_parallel(G, (z, G.shared(y)))
        else:
            images = G(z, G.shared(y))
    save_image(images, filename, scale_each=True, normalize=True)


def parse_args():
    usage = "Parser for conversion script."
    parser = argparse.ArgumentParser(description=usage)
    parser.add_argument(
        "--resolution",
        "-r",
        type=int,
        default=None,
        choices=[128, 256, 512],
        help="Resolution of TFHub module to convert. Converts all resolutions if None.",
    )
    parser.add_argument(
        "--redownload",
        action="store_true",
        default=False,
        help="Redownload weights and overwrite current hdf5 file, if present.",
    )
    parser.add_argument("--weights_dir", type=str, default="pretrained_weights")
    parser.add_argument("--samples_dir", type=str, default="pretrained_samples")
    parser.add_argument(
        "--no_ema", action="store_true", default=False, help="Do not load ema weights."
    )
    parser.add_argument(
        "--verbose", action="store_true", default=False, help="Additionally logging."
    )
    parser.add_argument(
        "--generate_samples",
        action="store_true",
        default=False,
        help="Generate test sample with pretrained model.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=64, help="Batch size used for test sample."
    )
    parser.add_argument(
        "--parallel", action="store_true", default=False, help="Parallelize G?"
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":

    args = parse_args()
    os.makedirs(args.weights_dir, exist_ok=True)
    os.makedirs(args.samples_dir, exist_ok=True)

    if args.resolution is not None:
        G = convert_biggan(
            args.resolution,
            args.weights_dir,
            redownload=args.redownload,
            no_ema=args.no_ema,
            verbose=args.verbose,
        )
        if args.generate_samples:
            filename = os.path.join(
                args.samples_dir, f"biggan{args.resolution}_samples.jpg"
            )
            print("Generating samples...")
            generate_sample(
                G, Z_DIMS[args.resolution], args.batch_size, filename, args.parallel
            )
    else:
        for res in RESOLUTIONS:
            G = convert_biggan(
                res,
                args.weights_dir,
                redownload=args.redownload,
                no_ema=args.no_ema,
                verbose=args.verbose,
            )
            if args.generate_samples:
                filename = os.path.join(args.samples_dir, f"biggan{res}_samples.jpg")
                print("Generating samples...")
                generate_sample(
                    G, Z_DIMS[res], args.batch_size, filename, args.parallel
                )
